{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wT7Qh_c5M3hW"
      },
      "outputs": [],
      "source": [
        "# 安装必要的Python库\n",
        "# unsloth: 一个能将模型微调速度提升2倍并减少显存占用的库\n",
        "# vllm: 一个用于快速高效LLM推理和服务的库，Unsloth在GRPO中会使用它来加速生成过程\n",
        "pip install unsloth==2025.3.19 vllm==0.8.2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "DkIvEkIIkEyB"
      },
      "outputs": [],
      "source": [
        "from unsloth import FastLanguageModel\n",
        "import torch\n",
        "\n",
        "# 设置模型的最大序列长度。可以根据需要增加此值以支持更长的推理和上下文。\n",
        "max_seq_length = 2048\n",
        "\n",
        "# 设置LoRA的秩（rank）。秩越高，模型可能变得更“智能”，但训练和推理的速度会变慢，显存占用也会增加。\n",
        "# LoRA是一种参数高效微调技术，通过训练小型的“适配器”矩阵来调整模型，而不是训练全部参数。\n",
        "lora_rank = 32\n",
        "\n",
        "# 从HuggingFace Hub加载预训练模型和分词器\n",
        "model, tokenizer = FastLanguageModel.from_pretrained(\n",
        "    # 指定要加载的预训练模型名称。可以是HuggingFace官方模型、本地模型或Unsloth优化后的模型。\n",
        "    model_name=\"Qwen/Qwen3-8B\",\n",
        "    # 设置模型的最大序列长度，与上面定义的变量一致。\n",
        "    max_seq_length=max_seq_length,\n",
        "    # 是否以4位精度加载模型。对于使用LoRA进行16位浮点数训练，此项应设置为False。\n",
        "    load_in_4bit=False,\n",
        "    # 是否启用vLLM进行快速推理。GRPO训练中生成多个响应时，此选项能显著提速。\n",
        "    fast_inference=True,\n",
        "    # 设置LoRA的最大秩，与上面定义的变量一致。\n",
        "    max_lora_rank=lora_rank,\n",
        "    # 设置GPU显存的使用率。如果遇到显存不足（OOM）的错误，可以适当降低此值。\n",
        "    gpu_memory_utilization=0.7,\n",
        ")\n",
        "\n",
        "# 为模型添加PEFT（Parameter-Efficient Fine-Tuning，参数高效微调）配置，这里使用LoRA。\n",
        "model = FastLanguageModel.get_peft_model(\n",
        "    model,\n",
        "    # LoRA的秩(r)，选择任何大于0的数字。建议值为 8, 16, 32, 64, 128。\n",
        "    r=lora_rank,\n",
        "    # target_modules 是一个列表，包含要应用LoRA技术的目标模块名称。\n",
        "    # 通常我们会选择注意力机制中的投影层和前馈网络中的层。\n",
        "    target_modules=[\n",
        "        \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",  # 注意力机制中的查询、键、值、输出投影\n",
        "        \"gate_proj\", \"up_proj\", \"down_proj\",     # 前馈网络中的门控、上行和下行投影\n",
        "    ],\n",
        "    # LoRA的alpha参数，通常设置为秩(r)的2倍，这是一种常见的做法，有助于稳定训练。\n",
        "    lora_alpha=lora_rank * 2,\n",
        "    # 是否使用梯度检查点技术。'unsloth'表示使用Unsloth的优化版本，可以显著减少训练时的显存占用。\n",
        "    use_gradient_checkpointing=\"unsloth\",\n",
        "    # 设置随机种子，以确保实验结果的可复现性。\n",
        "    random_state=3407,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6UjowCbT-cFz"
      },
      "outputs": [],
      "source": [
        "# 定义一些特殊的字符串标记，用于指导模型生成我们想要的格式。\n",
        "# 这是一种“格式提示”或“模板化”的方法，让模型学会生成带有思考过程和最终答案的结构化输出。\n",
        "\n",
        "# <start_working_out> 和 <end_working_out> 用于包裹模型的“思考过程”或“解题步骤”。\n",
        "reasoning_start = \"<start_working_out>\"\n",
        "reasoning_end   = \"<end_working_out>\"\n",
        "\n",
        "# <SOLUTION> 和 </SOLUTION> 用于包裹模型给出的最终、简洁的答案。\n",
        "solution_start  = \"<SOLUTION>\"\n",
        "solution_end    = \"</SOLUTION>\"\n",
        "\n",
        "# 定义系统提示（System Prompt）。这个提示会作为对话的初始指令，告诉模型它的角色和任务。\n",
        "# 在这里，我们要求模型先进行思考，将过程放在<start_working_out>和<end_working_out>之间，\n",
        "# 然后再将最终答案放在<SOLUTION>和</SOLUTION>之间。\n",
        "system_prompt = \\\n",
        "f\"\"\"You are given a problem.\n",
        "Think about the problem and provide your working out.\n",
        "Place it between {reasoning_start} and {reasoning_end}.\n",
        "Then, provide your solution between {solution_start}{solution_end}\"\"\"\n",
        "\n",
        "# 打印系统提示，查看其内容。\n",
        "system_prompt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y3fF9gMujY02"
      },
      "outputs": [],
      "source": [
        "# 创建一个自定义的聊天模板（Chat Template）。\n",
        "# 聊天模板使用Jinja2语法，定义了如何将多轮对话（包含system, user, assistant等角色）格式化为单个字符串，\n",
        "# 以便输入给模型进行训练或推理。\n",
        "chat_template = \\\n",
        "    \"{% if messages[0]['role'] == 'system' %}\"\\\n",
        "        \"{{ messages[0]['content'] + eos_token }}\"\\\n",
        "        \"{% set loop_messages = messages[1:] %}\"\\\n",
        "    \"{% else %}\"\\\n",
        "        \"{{ '{system_prompt}' + eos_token }}\"\\\n",
        "        \"{% set loop_messages = messages %}\"\\\n",
        "    \"{% endif %}\"\\\n",
        "    \"{% for message in loop_messages %}\"\\\n",
        "        \"{% if message['role'] == 'user' %}\"\\\n",
        "            \"{{ message['content'] }}\"\\\n",
        "        \"{% elif message['role'] == 'assistant' %}\"\\\n",
        "            \"{{ message['content'] + eos_token }}\"\\\n",
        "        \"{% endif %}\"\\\n",
        "    \"{% endfor %}\"\\\n",
        "    \"{% if add_generation_prompt %}{{ '{reasoning_start}' }}\"\\\n",
        "    \"{% endif %}\"\n",
        "\n",
        "# 将模板中的占位符替换为我们之前定义的特定字符串。\n",
        "# 这样做可以使模板适应我们自定义的格式要求。\n",
        "chat_template = chat_template\\\n",
        "    .replace(\"'{system_prompt}'\",   f\"'{system_prompt}'\")\\\n",
        "    .replace(\"'{reasoning_start}'\", f\"'{reasoning_start}'\")\n",
        "\n",
        "# 将我们创建的自定义聊天模板赋值给分词器（tokenizer）的chat_template属性。\n",
        "# 这样，之后调用tokenizer.apply_chat_template时，就会使用这个新模板。\n",
        "tokenizer.chat_template = chat_template"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BciEDYSSYFNj"
      },
      "outputs": [],
      "source": [
        "# 使用分词器的apply_chat_template方法来测试我们自定义的聊天模板。\n",
        "# 输入是一个包含多轮对话的列表，每轮对话是一个字典，包含'role'和'content'。\n",
        "# tokenize = False 表示我们只想看到格式化后的字符串，而不是token ID。\n",
        "# add_generation_prompt = True 会在末尾添加生成提示，这里是我们定义的'{reasoning_start}'，\n",
        "# 引导模型从“思考过程”开始生成回答。\n",
        "tokenizer.apply_chat_template([\n",
        "    {\"role\" : \"user\", \"content\" : \"What is 1+1?\"},\n",
        "    {\"role\" : \"assistant\", \"content\" : f\"{reasoning_start}I think it's 2.{reasoning_end}{solution_start}2{solution_end}\"},\n",
        "    {\"role\" : \"user\", \"content\" : \"What is 2+2?\"},\n",
        "], tokenize = False, add_generation_prompt = True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AXxM2lStVIkd"
      },
      "outputs": [],
      "source": [
        "from datasets import load_dataset\n",
        "import pandas as pd\n",
        "import numpy as np\n",
        "\n",
        "# 从HuggingFace Hub加载'unsloth/OpenMathReasoning-mini'数据集，并选择'cot' (Chain of Thought)子集。\n",
        "dataset = load_dataset(\"unsloth/OpenMathReasoning-mini\", split = \"cot\")\n",
        "\n",
        "# 将数据集转换为pandas DataFrame，并只保留我们需要的列。\n",
        "dataset = dataset.to_pandas()[\n",
        "    [\"expected_answer\", \"problem\", \"generated_solution\"]\n",
        "]\n",
        "\n",
        "# 数据清洗：过滤数据集，只保留'expected_answer'列是数值的行。\n",
        "# 首先，尝试将'expected_answer'列转换为数值类型，无法转换的将变为NaN (Not a Number)。\n",
        "is_number = pd.to_numeric(pd.Series(dataset[\"expected_answer\"]), errors = \"coerce\").notnull()\n",
        "# 然后，根据is_number布尔序列，筛选出值为数值的行。\n",
        "dataset = dataset.iloc[np.where(is_number)[0]]\n",
        "\n",
        "# 显示处理后的数据集，查看其结构和内容。\n",
        "dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Z9ydcV_Abfi6"
      },
      "outputs": [],
      "source": [
        "# 定义一个函数，用于将数据集中的每一行格式化为我们需要的对话格式。\n",
        "def format_dataset(x):\n",
        "    # 提取原始数据列\n",
        "    expected_answer = x[\"expected_answer\"] # 期望的答案\n",
        "    problem = x[\"problem\"] # 问题本身\n",
        "\n",
        "    # 清理模型生成的解决方案，移除其中已有的<think>和</think>标签。\n",
        "    thoughts = x[\"generated_solution\"]\n",
        "    thoughts = thoughts.replace(\"<think>\", \"\").replace(\"</think>\", \"\")\n",
        "\n",
        "    # 移除思考过程文本两端的空白字符（如换行符和空格）。\n",
        "    thoughts = thoughts.strip()\n",
        "\n",
        "    # 使用我们自定义的格式，将思考过程和最终答案组合成一个字符串。\n",
        "    final_prompt = \\\n",
        "        reasoning_start + thoughts + reasoning_end + \\\n",
        "        solution_start + expected_answer + solution_end\n",
        "    \n",
        "    # 返回一个列表，其中包含符合聊天模板输入格式的字典。\n",
        "    return [\n",
        "        {\"role\" : \"system\",    \"content\" : system_prompt},\n",
        "        {\"role\" : \"user\",      \"content\" : problem},\n",
        "        {\"role\" : \"assistant\", \"content\" : final_prompt},\n",
        "    ]\n",
        "\n",
        "# 使用.apply方法将format_dataset函数应用到DataFrame的每一行，并将结果存储在新列'Messages'中。\n",
        "dataset[\"Messages\"] = dataset.apply(format_dataset, axis = 1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LTdXBKcslhRH"
      },
      "outputs": [],
      "source": [
        "# 再次调用apply_chat_template，这次应用于处理后的数据集的第一行，\n",
        "# 以验证格式化是否正确。\n",
        "tokenizer.apply_chat_template(dataset[\"Messages\"][0], tokenize = False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MBHFlRbae9_s"
      },
      "outputs": [],
      "source": [
        "# 计算每条格式化后的消息被分词器处理后得到的token数量，并将其存储在新列'N'中。\n",
        "dataset[\"N\"] = dataset[\"Messages\"].apply(lambda x: len(tokenizer.apply_chat_template(x)))\n",
        "\n",
        "# 过滤数据集，只保留token数量小于等于max_seq_length/2的样本。\n",
        "# 这样做是为了确保在训练时，输入（prompt）不会太长，能给模型的生成（completion）留下足够的空间，\n",
        "# 避免因超出最大序列长度而被截断。\n",
        "dataset = dataset.loc[dataset[\"N\"] <= max_seq_length/2].copy()\n",
        "\n",
        "# 显示过滤后数据集的形状（行数，列数）。\n",
        "dataset.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3rgdtiV_f5hx"
      },
      "outputs": [],
      "source": [
        "from datasets import Dataset\n",
        "\n",
        "# 将所有格式化好的消息（Messages列）应用聊天模板，生成最终的训练文本，并存储在'text'列中。\n",
        "dataset[\"text\"] = tokenizer.apply_chat_template(dataset[\"Messages\"].values.tolist(), tokenize = False)\n",
        "\n",
        "# 将pandas DataFrame转换回HuggingFace的Dataset对象，这是SFTTrainer所要求的格式。\n",
        "dataset = Dataset.from_pandas(dataset)\n",
        "\n",
        "# 显示最终准备好的Dataset对象信息。\n",
        "dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "woYi0SSygpqp"
      },
      "outputs": [],
      "source": [
        "from trl import SFTTrainer, SFTConfig\n",
        "\n",
        "# 实例化SFTTrainer，用于进行有监督微调。\n",
        "trainer = SFTTrainer(\n",
        "    model = model, # 已经加载并添加了LoRA适配器的模型\n",
        "    tokenizer = tokenizer, # 配套的分词器\n",
        "    train_dataset = dataset, # 训练数据集\n",
        "    args = SFTConfig( # SFT训练的配置参数\n",
        "        # 指定数据集中包含完整训练文本的列名。\n",
        "        dataset_text_field = \"text\",\n",
        "        # 每个设备上的训练批次大小。\n",
        "        per_device_train_batch_size = 1,\n",
        "        # 梯度累积步数。实际的批次大小 = per_device_train_batch_size * gradient_accumulation_steps。\n",
        "        # 用于在显存有限的情况下模拟更大的批次大小。\n",
        "        gradient_accumulation_steps = 1,\n",
        "        # 预热步数，在训练初期使用较小的学习率，有助于稳定训练。\n",
        "        warmup_steps = 5,\n",
        "        # 训练的总轮数（epochs）。\n",
        "        num_train_epochs = 2,\n",
        "        # 学习率。对于较长时间的训练，可以适当减小此值，例如2e-5。\n",
        "        learning_rate = 2e-4,\n",
        "        # 每隔多少步记录一次日志。\n",
        "        logging_steps = 5,\n",
        "        # 优化器类型。'adamw_8bit'是一种内存效率更高的AdamW优化器。\n",
        "        optim = \"adamw_8bit\",\n",
        "        # 权重衰减，一种正则化技术，防止过拟合。\n",
        "        weight_decay = 0.01,\n",
        "        # 学习率调度器类型，'linear'表示学习率会从初始值线性衰减到0。\n",
        "        lr_scheduler_type = \"linear\",\n",
        "        # 随机种子，保证训练过程的可复现性。\n",
        "        seed = 3407,\n",
        "        # 将训练日志报告到指定的平台，如'wandb'（Weights & Biases）或'tensorboard'。\n",
        "        # 'swanlab'是一个类似的实验跟踪工具。\n",
        "        report_to = \"swanlab\",\n",
        "    ),\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l4-2v_bLhZuE"
      },
      "outputs": [],
      "source": [
        "# 开始SFT训练过程。\n",
        "trainer.train()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9HJxrS76h3Ds"
      },
      "outputs": [],
      "source": [
        "# SFT训练后进行推理测试\n",
        "# 从数据集中取出一条样本的前两部分（system和user消息）作为推理的输入。\n",
        "text = tokenizer.apply_chat_template(\n",
        "    dataset[0][\"Messages\"][:2],\n",
        "    tokenize = False,\n",
        "    add_generation_prompt = True, # 必须添加此项以触发模型的生成\n",
        ")\n",
        "\n",
        "from transformers import TextStreamer\n",
        "\n",
        "# 使用model.generate进行推理\n",
        "_ = model.generate(\n",
        "    # 将文本输入转换为PyTorch张量，并移动到CUDA设备上。\n",
        "    **tokenizer(text, return_tensors = \"pt\").to(\"cuda\"),\n",
        "    # temperature=0 表示使用贪心解码，选择概率最高的token，生成结果更具确定性。\n",
        "    temperature = 0,\n",
        "    # 设置生成的最大新token数量。\n",
        "    max_new_tokens = 1024,\n",
        "    # 使用TextStreamer可以流式输出结果，即逐个token地打印生成内容，而不是等全部生成完再输出。\n",
        "    streamer = TextStreamer(tokenizer, skip_prompt = False),\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YWSZ0DET7bob"
      },
      "outputs": [],
      "source": [
        "# --- GRPO 训练部分 --- #\n",
        "# 为了进行GRPO训练，我们首先清理内存，为加载新数据集做准备。\n",
        "\n",
        "# 删除之前的数据集对象\n",
        "del dataset\n",
        "# 清空PyTorch的CUDA缓存，释放未被引用的GPU内存\n",
        "torch.cuda.empty_cache()\n",
        "# 导入垃圾回收模块并执行垃圾回收\n",
        "import gc\n",
        "gc.collect()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "o7-eUrQn-OzE"
      },
      "outputs": [],
      "source": [
        "# 为GRPO训练加载一个新的、更大的数学推理数据集\n",
        "from datasets import load_dataset\n",
        "dataset = load_dataset(\"open-r1/DAPO-Math-17k-Processed\", \"en\", split = \"train\")\n",
        "dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "siopxjG8-ReF"
      },
      "outputs": [],
      "source": [
        "# 查看新数据集的一条样本的问题（prompt）\n",
        "dataset[0][\"prompt\"]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KGupRQqD-Wcf"
      },
      "outputs": [],
      "source": [
        "# 查看该样本对应的解决方案（solution）\n",
        "dataset[0][\"solution\"]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8JJGXKdJ-Zl_"
      },
      "outputs": [],
      "source": [
        "# 定义一个函数来提取答案。在这个数据集中，答案直接就是solution字段，\n",
        "# 但在其他数据集中（如GSM8K），答案可能被####标记包围，这个函数是为此类情况准备的。\n",
        "def extract_hash_answer(text):\n",
        "    # if \"####\" not in text: return None\n",
        "    # return text.split(\"####\")[1].strip()\n",
        "    return text\n",
        "\n",
        "extract_hash_answer(dataset[0][\"solution\"])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qyEVI972-d3n"
      },
      "outputs": [],
      "source": [
        "# 再次对新数据集进行格式化，以符合我们的对话模板。\n",
        "# 这次我们只准备system和user角色的消息，assistant部分将由模型在GRPO训练中生成。\n",
        "dataset = dataset.map(lambda x: {\n",
        "    \"prompt\" : [\n",
        "        {\"role\": \"system\", \"content\": system_prompt},\n",
        "        {\"role\": \"user\",   \"content\": x[\"prompt\"]},\n",
        "    ],\n",
        "    # 同时提取出标准答案，用于后续的奖励函数评估。\n",
        "    \"answer\": extract_hash_answer(x[\"solution\"]),\n",
        "})\n",
        "\n",
        "# 查看格式化后的第一条数据。\n",
        "dataset[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iQwjTjNz-gY_"
      },
      "outputs": [],
      "source": [
        "# --- 定义GRPO的奖励函数 --- #\n",
        "# 奖励函数是GRPO的核心，它评估模型生成的回答的质量，并返回一个分数。\n",
        "# 分数越高，表示生成的回答越好。\n",
        "\n",
        "import re\n",
        "\n",
        "# 奖励函数1: 精确格式匹配\n",
        "# 我们定义一个正则表达式来检查模型的输出是否严格遵循了 <end...><SOLUTION>...</SOLUTION> 的格式。\n",
        "\n",
        "# 首先，创建一个正则表达式片段，用于匹配可能存在或不存在的EOS（end-of-sentence）令牌。\n",
        "solution_end_regex = r\"</SOLUTION>[\\s]{0,}\" + \\\n",
        "    \"(?:\" + re.escape(tokenizer.eos_token) + \")?\"\n",
        "\n",
        "# 编译完整的正则表达式\n",
        "match_format = re.compile(\n",
        "    rf\"{reasoning_end}.*?\"\\\n",
        "    rf\"{solution_start}(.+?){solution_end_regex}\"\\\n",
        "    rf\"[\\s]{{0,}}$\",\n",
        "    flags = re.MULTILINE | re.DOTALL\n",
        ")\n",
        "match_format"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ndzHnQ_6-jHt"
      },
      "outputs": [],
      "source": [
        "# 测试正则表达式能否成功从一个符合格式的字符串中提取出答案 '2'。\n",
        "match_format.findall(\n",
        "    \"Let me think!<end_working_out>\"\\\n",
        "    f\"<SOLUTION>\\n2\\n</SOLUTION>\",\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eRMDAzDk2x6t"
      },
      "outputs": [],
      "source": [
        "# 再次测试，即使答案前后有空格和换行符，也应该能正确提取。\n",
        "match_format.findall(\n",
        "    \"<start_working_out>Let me think!<end_working_out>\"\\\n",
        "    f\"<SOLUTION>  2  </SOLUTION>\\n\\n\",\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cUfHzCVx-nGK"
      },
      "outputs": [],
      "source": [
        "# 奖励函数2: 近似格式匹配\n",
        "# 这个函数不检查答案是否正确，只检查格式。它计算输出中包含了多少个我们定义的特殊标签。\n",
        "# 如果模型生成了所有必需的标签（每个一次），它会得到正分；否则会得到负分。\n",
        "def match_format_approximately(completions, **kwargs):\n",
        "    scores = []\n",
        "    for completion in completions:\n",
        "        score = 0\n",
        "        response = completion[0][\"content\"]\n",
        "        \n",
        "        # 计算每个标签出现的次数，如果恰好是1次，则加分，否则扣分。\n",
        "        # 无需奖励<start_working_out>，因为我们总是在提示中预先添加它。\n",
        "        score += 0.5 if response.count(reasoning_end)   == 1 else -1.0\n",
        "        score += 0.5 if response.count(solution_start)  == 1 else -1.0\n",
        "        score += 0.5 if response.count(solution_end)    == 1 else -1.0\n",
        "        scores.append(score)\n",
        "    return scores"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hmtI_8gg-uIE"
      },
      "outputs": [],
      "source": [
        "# 奖励函数3: 检查答案是否正确（字符串匹配）\n",
        "# 这个函数使用我们之前定义的正则表达式提取模型给出的答案，并与标准答案进行比较。\n",
        "def check_answer(prompts, completions, answer, **kwargs):\n",
        "    # 获取用户的问题和模型生成的所有回答\n",
        "    question = prompts[0][-1][\"content\"]\n",
        "    responses = [completion[0][\"content\"] for completion in completions]\n",
        "\n",
        "    # 使用正则表达式从每个回答中提取出<SOLUTION>标签内的内容\n",
        "    extracted_responses = [\n",
        "        guess.group(1)\n",
        "        if (guess := match_format.search(r)) is not None else None \\\n",
        "        for r in responses\n",
        "    ]\n",
        "\n",
        "    scores = []\n",
        "    # 遍历每个提取出的答案和对应的标准答案\n",
        "    for guess, true_answer in zip(extracted_responses, answer):\n",
        "        score = 0\n",
        "        # 如果没有提取到答案（格式错误），给予重罚。\n",
        "        if guess is None:\n",
        "            scores.append(-2.0)\n",
        "            continue\n",
        "        # 如果答案完全正确，给予最高分！\n",
        "        if guess == true_answer:\n",
        "            score += 5.0\n",
        "        # 如果去除空格后答案正确，也给予较高分数。\n",
        "        elif guess.strip() == true_answer.strip():\n",
        "            score += 3.5\n",
        "        else:\n",
        "            # 对于数值答案，我们也可以奖励近似正确的答案。\n",
        "            # 如果答案在真实答案的±10%范围内，给予奖励。\n",
        "            try:\n",
        "                ratio = float(guess) / float(true_answer)\n",
        "                if   ratio >= 0.9 and ratio <= 1.1: score += 2.0\n",
        "                elif ratio >= 0.8 and ratio <= 1.2: score += 1.5\n",
        "                else: score -= 2.5 # 错误答案给予惩罚\n",
        "            except:\n",
        "                # 如果无法转换为浮点数，给予重罚。\n",
        "                score -= 4.5\n",
        "        scores.append(score)\n",
        "    return scores"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AVW0kL8q-wL5"
      },
      "outputs": [],
      "source": [
        "# 奖励函数4: 检查答案是否为数值（更宽松的数值匹配）\n",
        "# 定义一个新的正则表达式，专门用于从<SOLUTION>标签中提取数值。\n",
        "match_numbers = re.compile(\n",
        "    solution_start + r\".*?[\\s]{0,}([-]?[\\d\\.\\,]{1,})\",\n",
        "    flags = re.MULTILINE | re.DOTALL\n",
        ")\n",
        "# 测试正则表达式\n",
        "print(match_numbers.findall(\"<SOLUTION>  0.34  </SOLUTION>\"))\n",
        "print(match_numbers.findall(\"<SOLUTION>  123,456  </SOLUTION>\"))\n",
        "print(match_numbers.findall(\"<SOLUTION>  -0.234  </SOLUTION>\"))\n",
        "print(match_numbers.findall(\"<SOLUTION>17</SOLUTION>\"))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GjBFrttr-y1_"
      },
      "outputs": [],
      "source": [
        "# 定义全局变量，用于控制日志打印的频率。\n",
        "global PRINTED_TIMES\n",
        "PRINTED_TIMES = 0\n",
        "global PRINT_EVERY_STEPS\n",
        "PRINT_EVERY_STEPS = 5\n",
        "\n",
        "# 奖励函数4的实现\n",
        "def check_numbers(prompts, completions, answer, **kwargs):\n",
        "    question = prompts[0][-1][\"content\"]\n",
        "    responses = [completion[0][\"content\"] for completion in completions]\n",
        "\n",
        "    # 使用新的正则表达式提取数值\n",
        "    extracted_responses = [\n",
        "        guess.group(1)\n",
        "        if (guess := match_numbers.search(r)) is not None else None \\\n",
        "        for r in responses\n",
        "    ]\n",
        "\n",
        "    scores = []\n",
        "    # 为了便于调试，每隔几步打印一次问题、答案、模型响应和提取结果。\n",
        "    global PRINTED_TIMES\n",
        "    global PRINT_EVERY_STEPS\n",
        "    if PRINTED_TIMES % PRINT_EVERY_STEPS == 0:\n",
        "        print(\n",
        "            '*'*20 + f\"Question:\\n{question}\", f\"\\nAnswer:\\n{answer[0]}\", f\"\\nResponse:\\n{responses[0]}\", f\"\\nExtracted:\\n{extracted_responses[0]}\"\n",
        "        )\n",
        "    PRINTED_TIMES += 1\n",
        "\n",
        "    for guess, true_answer in zip(extracted_responses, answer):\n",
        "        # 如果没有提取到数值，给予重罚。\n",
        "        if guess is None:\n",
        "            scores.append(-2.5)\n",
        "            continue\n",
        "        # 尝试将提取的字符串和标准答案都转换为浮点数进行比较。\n",
        "        try:\n",
        "            true_answer = float(true_answer.strip())\n",
        "            # 移除逗号，如 '123,456' -> '123456'\n",
        "            guess       = float(guess.strip().replace(\",\", \"\"))\n",
        "            # 如果数值完全相等，给予高分，否则给予惩罚。\n",
        "            scores.append(3.5 if guess == true_answer else -1.5)\n",
        "        except:\n",
        "            # 如果转换失败，不给分也不扣分。\n",
        "            scores.append(0)\n",
        "            continue\n",
        "    return scores"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6EgAi4Q5fGE-"
      },
      "outputs": [],
      "source": [
        "# 再次对数据集进行预处理，为GRPO训练做准备\n",
        "\n",
        "# 将数据集的prompt部分（system和user消息）分词，并存储token\n",
        "tokenized = dataset.map(\n",
        "    lambda x: {\"tokens\" : tokenizer.apply_chat_template(x[\"prompt\"], add_generation_prompt = True, tokenize = True)},\n",
        "    batched = True,\n",
        ")\n",
        "# 打印第一条数据分词后解码的结果，进行验证\n",
        "print(tokenizer.decode(tokenized[0][\"tokens\"]))\n",
        "# 计算每条prompt的token长度\n",
        "tokenized = tokenized.map(lambda x: {\"L\" : len(x[\"tokens\"])})\n",
        "\n",
        "import numpy as np\n",
        "# 计算所有prompt长度的90%分位数，以此作为最大prompt长度的参考。\n",
        "# 这样可以过滤掉极少数过长的prompt，使训练更稳定高效。\n",
        "maximum_length = int(np.quantile(tokenized[\"L\"], 0.9))\n",
        "print(\"Max Length = \", maximum_length)\n",
        "\n",
        "# 根据计算出的最大长度，过滤数据集。\n",
        "dataset = dataset.select(np.where(np.array(tokenized[\"L\"]) <= maximum_length)[0])\n",
        "del tokenized"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ptqkXK2D4d6p"
      },
      "outputs": [],
      "source": [
        "# 设置GRPO训练的参数\n",
        "\n",
        "# prompt的最大长度，留一个token的余量\n",
        "max_prompt_length = maximum_length + 1\n",
        "# completion（模型生成部分）的最大长度\n",
        "max_completion_length = max_seq_length - max_prompt_length\n",
        "\n",
        "from vllm import SamplingParams\n",
        "# 配置vLLM的采样参数，这些参数在GRPO训练中用于从模型生成多个不同的回答。\n",
        "vllm_sampling_params = SamplingParams(\n",
        "    min_p = 0.1, # 最小概率采样（Min-P），忽略概率低于此值的token\n",
        "    top_p = 1.0, # Top-P（核）采样，从累积概率超过p的最小token集合中采样\n",
        "    top_k = -1,  # Top-K采样，-1表示不启用\n",
        "    seed = 3407, # 采样种子，保证可复现性\n",
        "    stop = [tokenizer.eos_token], # 遇到EOS令牌时停止生成\n",
        "    include_stop_str_in_output = True, # 在输出中包含停止符\n",
        ")\n",
        "\n",
        "from trl import GRPOConfig, GRPOTrainer\n",
        "# 配置GRPO训练参数\n",
        "training_args = GRPOConfig(\n",
        "    vllm_sampling_params = vllm_sampling_params, # 传入vLLM采样参数\n",
        "    temperature = 1.0, # 生成时的温度，值越高，随机性越强\n",
        "    learning_rate = 5e-6, # GRPO的学习率通常比SFT小\n",
        "    weight_decay = 0.01,\n",
        "    warmup_ratio = 0.1, # 预热步数占总步数的比例\n",
        "    lr_scheduler_type = \"linear\",\n",
        "    optim = \"adamw_8bit\",\n",
        "    logging_steps = 1, # 每一步都记录日志，便于观察\n",
        "    per_device_train_batch_size = 1, # GRPO中，这个值通常会被num_generations覆盖\n",
        "    gradient_accumulation_steps = 1, # 梯度累积，可以增加到4以获得更平滑的训练\n",
        "    num_generations = 4, # 每个prompt生成4个不同的回答进行评估，如果显存不足可以减小此值\n",
        "    max_prompt_length = max_prompt_length, # prompt最大长度\n",
        "    max_completion_length = max_completion_length, # 生成内容最大长度\n",
        "    # num_train_epochs = 1, # 训练总轮数\n",
        "    max_steps = 100, # 为了快速演示，这里只训练100步\n",
        "    save_steps = 100, # 每100步保存一次模型\n",
        "    report_to = \"swanlab\", # 日志报告平台\n",
        "    output_dir = \"outputs\", # 模型和日志输出目录\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vzOuSVCL_GA9"
      },
      "outputs": [],
      "source": [
        "# 实例化GRPOTrainer\n",
        "trainer = GRPOTrainer(\n",
        "    model = model, # 我们的基础模型\n",
        "    processing_class = tokenizer, # 分词器\n",
        "    reward_funcs = [ # 传入我们定义的所有奖励函数\n",
        "        # match_format_exactly, # 这个函数被注释掉了，因为它太严格了\n",
        "        match_format_approximately,\n",
        "        check_answer,\n",
        "        check_numbers,\n",
        "    ],\n",
        "    args = training_args, # 传入训练配置\n",
        "    train_dataset = dataset, # 训练数据集\n",
        ")\n",
        "# 开始GRPO训练\n",
        "trainer.train()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qtcz_lpbVC92"
      },
      "outputs": [],
      "source": [
        "# GRPO训练后进行推理测试\n",
        "text = \"What is the sqrt of 101?\"\n",
        "\n",
        "from vllm import SamplingParams\n",
        "# 定义推理时的采样参数\n",
        "sampling_params = SamplingParams(\n",
        "    temperature = 1.0,\n",
        "    top_k = 50,\n",
        "    max_tokens = 1024,\n",
        ")\n",
        "\n",
        "# 使用Unsloth的fast_generate方法进行快速推理\n",
        "output = model.fast_generate(\n",
        "    [text],\n",
        "    sampling_params = sampling_params,\n",
        "    lora_request = None, # 因为LoRA权重已经合并到模型中，所以这里是None\n",
        ")[0].outputs[0].text\n",
        "\n",
        "output"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AL-BcuB1VLIv"
      },
      "outputs": [],
      "source": [
        "# 将训练好的LoRA适配器权重保存到本地目录\n",
        "model.save_lora(\"grpo_saved_lora\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4SfdI-ERbpiw"
      },
      "outputs": [],
      "source": [
        "# 这是一个验证步骤，确保我们保存的LoRA权重是有效的。\n",
        "from safetensors import safe_open\n",
        "\n",
        "tensors = {}\n",
        "# 使用safetensors库安全地打开保存的适配器模型文件\n",
        "with safe_open(\"grpo_saved_lora/adapter_model.safetensors\", framework = \"pt\") as f:\n",
        "    # 遍历文件中的所有张量（tensors）\n",
        "    for key in f.keys():\n",
        "        tensor = f.get_tensor(key)\n",
        "        # 计算张量中零元素的比例\n",
        "        n_zeros = (tensor == 0).sum() / tensor.numel()\n",
        "        # 断言：确保张量不全为零，证明训练确实改变了权重。\n",
        "        assert(n_zeros.item() != tensor.numel())\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zf_OY5WMVOxF"
      },
      "outputs": [],
      "source": [
        "# 演示如何加载已保存的LoRA适配器进行推理\n",
        "messages = [\n",
        "    {\"role\": \"system\", \"content\": system_prompt},\n",
        "    {\"role\": \"user\",   \"content\": \"What is the sqrt of 101?\"},\n",
        "]\n",
        "\n",
        "# 同样，先应用聊天模板\n",
        "text = tokenizer.apply_chat_template(\n",
        "    messages,\n",
        "    add_generation_prompt = True,\n",
        "    tokenize = False,\n",
        ")\n",
        "\n",
        "from vllm import SamplingParams\n",
        "sampling_params = SamplingParams(\n",
        "    temperature = 1.0,\n",
        "    top_k = 50,\n",
        "    max_tokens = 2048,\n",
        ")\n",
        "\n",
        "# 在fast_generate中，通过lora_request参数加载指定的LoRA适配器\n",
        "output = model.fast_generate(\n",
        "    text,\n",
        "    sampling_params = sampling_params,\n",
        "    lora_request = model.load_lora(\"grpo_saved_lora\"),\n",
        ")[0].outputs[0].text\n",
        "\n",
        "output"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QyEjW-WuYQIm"
      },
      "outputs": [],
      "source": [
        "# --- 模型保存与上传（GGUF格式） --- #\n",
        "# 以下代码块展示了如何将微调后的模型保存为GGUF格式，并推送到HuggingFace Hub。\n",
        "# GGUF是一种专为llama.cpp等推理框架设计的格式，支持量化，可以在CPU上高效运行。\n",
        "# 所有代码都被if False:包裹，表示默认不执行。如需使用，请将False改为True。\n",
        "\n",
        "# 保存为 8bit Q8_0 GGUF 格式（适用于GGUF量化模型推理）\n",
        "if False:\n",
        "    model.save_pretrained_gguf(\"model\", tokenizer)\n",
        "\n",
        "# 上传 Q8_0 量化模型到 HuggingFace Hub\n",
        "# 请前往 https://huggingface.co/settings/tokens 获取你的访问令牌（token）\n",
        "# 并将 \"hf\" 替换为你的HuggingFace用户名\n",
        "if False:\n",
        "    model.push_to_hub_gguf(\"hf/model\", tokenizer, token=\"YOUR_HF_TOKEN\")\n",
        "\n",
        "# 保存为 16bit GGUF 格式（即未量化版本，保留完整精度，适用于对精度要求高的任务）\n",
        "if False:\n",
        "    model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"f16\")\n",
        "\n",
        "# 上传 16bit GGUF 模型到 HuggingFace Hub\n",
        "if False:\n",
        "    model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method=\"f16\", token=\"YOUR_HF_TOKEN\")\n",
        "\n",
        "# 保存为 q4_k_m（4位 K-type）GGUF 格式，这是一种在推理效率与模型精度之间取得良好平衡的常用量化方法。\n",
        "if False:\n",
        "    model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"q4_k_m\")\n",
        "\n",
        "# 上传 q4_k_m 模型到 HuggingFace Hub\n",
        "if False:\n",
        "    model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method=\"q4_k_m\", token=\"YOUR_HF_TOKEN\")\n",
        "\n",
        "# 如果你想一次性上传多个不同量化精度的GGUF版本，可以使用以下方法，速度更快。\n",
        "if False:\n",
        "    model.push_to_hub_gguf(\n",
        "        \"hf/model\",  # 同样，替换 \"hf\" 为你的HuggingFace用户名\n",
        "        tokenizer,\n",
        "        quantization_method=[\"q4_k_m\", \"q8_0\", \"q5_k_m\"],  # 可根据需要选择要上传的量化格式\n",
        "        token=\"YOUR_HF_TOKEN\"\n",
        "    )\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
